Skip to content

Optimizes Native FSDP memory_efficient_init weight loading for multi-node EP/FSDP jobs & add mutlti-node scripts#207

Merged
tpx818 merged 7 commits into
modelscope:mainfrom
meichangsu1:ep_memory_eff_init
Jun 2, 2026
Merged

Optimizes Native FSDP memory_efficient_init weight loading for multi-node EP/FSDP jobs & add mutlti-node scripts#207
tpx818 merged 7 commits into
modelscope:mainfrom
meichangsu1:ep_memory_eff_init

Conversation

@meichangsu1
Copy link
Copy Markdown
Collaborator

@meichangsu1 meichangsu1 commented May 29, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

This PR improves EP + FSDP2 + LoRA training for DeepSeek V4 Flash in multi-node environments.

Main changes:

  1. Optimize Native FSDP memory_efficient_init weight loading for multi-node EP/FSDP jobs.
    Previously, global rank 0 distributed pretrained weights and EP expert shards to all ranks. With large world_size, especially when ep_fsdp_size=1, this put heavy communication pressure on global rank 0.
    This PR changes the loading path so each node's local rank 0 loads/captures the full pretrained state and distributes tensors only to ranks on the same node. EP shard selection still uses rank_to_ep_rank, so each target rank receives the correct EP slice before FSDP/DTensor placement.

  2. Optimize LoRA checkpoint saving.
    LoRA adapter saving no longer calls full-model get_full_state_dict() and then filters LoRA keys. Instead, it collects only LoRA adapter parameters. Native FSDP keeps EP-aware LoRA all-gather, avoiding large base-model state_dict materialization during adapter checkpoint saving.

  3. Remove the Twinkle-side Accelerate FSDP2 state-dict loading monkey patch.
    AccelerateStrategy now relies on native Accelerate behavior for memory_efficient_init / cpu_ram_efficient_loading.

  4. Add multi-node DeepSeek V4 Flash EP + FSDP2 + LoRA cookbook script.
    The DeepSeek V4 LoRA cookbook now supports configuring GPU/NPU count via NUM_GPUS.

Experiment results

[2026-06-01 07:35:49][INFO:twinkle] Current is step 4 of 16, metric: {'loss': '3.0792', 'learning rate(param group 1)': '0.000000e+00', 'learning rate(param group 2)': '0.000000e+00', 'iters': 0, 'total time elapse': '2.9 minutes', 'speed': '0.00 iters/s'}
[2026-06-01 07:36:16][INFO:twinkle] Current is step 8 of 16, metric: {'loss': '2.9792', 'grad_norm': '135.397873', 'learning rate(param group 1)': '2.000000e-05', 'learning rate(param group 2)': '2.000000e-05', 'iters': 1, 'total time elapse': '200 seconds', 'speed': '0.04 iters/s'}
[2026-06-01 07:36:51][INFO:twinkle] Current is step 12 of 16, metric: {'loss': '3.0201', 'grad_norm': '136.174759', 'learning rate(param group 1)': '4.000000e-05', 'learning rate(param group 2)': '4.000000e-05', 'iters': 2, 'total time elapse': '236 seconds', 'speed': '0.03 iters/s'}
[2026-06-01 07:37:22][INFO:twinkle] Saved final adapter to /nas/diskz/checkpoint-final

qq_30035749 added 3 commits May 29, 2026 15:47
Load full pretrained weights on each node's local rank0 and distribute shards
only within the node, reducing global rank0 pressure for large EP/FSDP jobs.
@meichangsu1 meichangsu1 marked this pull request as draft May 29, 2026 10:45
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request removes the FSDP2 state-dict loading patch from the Accelerate strategy and implements node-local state-dict loading and scattering in the Native FSDP strategy using local rank topology. The review feedback suggests optimizing the node-local communication by creating a node-local process group once and using dist.broadcast instead of inefficient point-to-point dist.send and dist.recv calls for every parameter. Additionally, it is recommended to verify that distributed training is initialized before retrieving the local rank to prevent runtime errors in non-distributed environments.

Comment thread src/twinkle/model/transformers/strategy/native_fsdp.py
Comment thread src/twinkle/model/transformers/strategy/native_fsdp.py
Comment thread src/twinkle/model/transformers/strategy/native_fsdp.py
Implement `get_adapter_state_dict` methods in AccelerateStrategy and NativeFSDPStrategy to efficiently collect only LoRA adapter parameters, avoiding full model state dict collection for large FSDP/EP jobs. The NativeFSDP version includes EP-aware all-gather for expert parameters.
@meichangsu1 meichangsu1 changed the title Ep memory eff init Optimizes Native FSDP memory_efficient_init weight loading for multi-node EP/FSDP jobs & checkpoint-saving logic Jun 1, 2026
qq_30035749 added 3 commits June 1, 2026 16:03
The previous implementation used individual send/recv operations for each target rank, which was inefficient and could cause performance bottlenecks. This change replaces them with a single broadcast call using a new local group, improving communication efficiency and reducing code complexity.
Avoid HCCL subgroup broadcast for node-local weight loading, since dynamic
subgroup communicators can fail on NPU. Fall back to send/recv on HCCL while
keeping local broadcast for other backends.
@meichangsu1 meichangsu1 changed the title Optimizes Native FSDP memory_efficient_init weight loading for multi-node EP/FSDP jobs & checkpoint-saving logic Optimizes Native FSDP memory_efficient_init weight loading for multi-node EP/FSDP jobs & add mutlti-node scripts Jun 1, 2026
@tpx818 tpx818 marked this pull request as ready for review June 1, 2026 10:33
Comment thread src/twinkle/model/transformers/strategy/accelerate.py
@tpx818 tpx818 merged commit b14e67a into modelscope:main Jun 2, 2026
1 of 3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants